MIRNet Paper Reading

Learning Enriched Features for Real Image Restoration and Enhancement

The MRB whose main branch is dedicated to maintaining spatially-precise high-resolution representations through the entire network and the complementary set of parallel branches provide better contextualized features

主支路采用的是full-size的卷积,在每一个串联的block中,将feat进行下采,然后过DAU,然后进行SKFF。

image|690x328

MRB

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def forward(self, x):
inp = x.clone()
#col 1 only
blocks_out = []
for j in range(self.height):
if j==0:
inp = self.blocks[j][0](inp)
else:
inp = self.blocks[j][0](self.down[f'{inp.size(1)}_{2}'](inp))
blocks_out.append(inp)

#rest of grid
for i in range(1,self.width):
#Mesh
# Replace condition(i%2!=0) with True(Mesh) or False(Plain)
# if i%2!=0:
if True:
tmp=[]
for j in range(self.height):
TENSOR = []
nfeats = (2**j)*self.n_feat
for k in range(self.height):
TENSOR.append(self.select_up_down(blocks_out[k], j, k))

selective_kernel_fusion = self.selective_kernel[j](TENSOR)
tmp.append(selective_kernel_fusion)
#Plain
else:
tmp = blocks_out
#Forward through either mesh or plain
for j in range(self.height):
blocks_out[j] = self.blocks[j][i](tmp[j])

#Sum after grid
out=[]
for k in range(self.height):
out.append(self.select_last_up(blocks_out[k], k))

out = self.selective_kernel[0](out)

out = self.conv_out(out)
out = out + x

return out

每一个MRB看作2维的table,用height 和width表示层数和每层经过的卷积数。值得注意的是每个特征的scale不同,具体fuse方法是在SKFF中。

SKFF:

在输入SKFF之前,每一个SKFF都会选择3个scale的特征,以当前height为例,上层的下采,下层的上采,3个大小一致的输入到SKFF中。

select different scale code
1
2
3
4
5
6
7
8
9
def select_up_down(self, tensor, j, k):
if j==k:
return tensor
else:
diff = 2 ** np.abs(j-k)
if j<k:
return self.up[f'{tensor.size(1)}_{diff}'](tensor)
else:
return self.down[f'{tensor.size(1)}_{diff}'](tensor)

</details>
每三个同大小的feat,预测一个mask然后乘回原来的进行相加。

image|690x229